import base64
import io
import json
import re
from typing import Any, Sequence

import filetype
from lxml import etree as LET  # nosec: B410
from pydantic import BaseModel, ConfigDict, field_serializer, model_validator

from dynamiq.nodes.agents.exceptions import JSONParsingError, ParsingError, TagNotFoundError, XMLParsingError
from dynamiq.prompts import (
    Message,
    MessageRole,
    VisionMessage,
    VisionMessageImageContent,
    VisionMessageImageURL,
    VisionMessageTextContent,
)
from dynamiq.storages.file.base import FileInfo
from dynamiq.utils.logger import logger
from dynamiq.utils.utils import CHARS_PER_TOKEN

TOOL_MAX_TOKENS = 64000


def convert_bytesio_to_file_info(bytesio_obj: io.BytesIO, key: str, index: int | None = None) -> FileInfo:
    """
    Convert a BytesIO object to a FileInfo object with base64 encoded content.

    Args:
        bytesio_obj (io.BytesIO): Buffer holding the raw file content. Optional
            ``name``, ``content_type`` and ``description`` attributes on the
            buffer are honored when present.
        key (str): Identifier used to build a fallback file name.
        index (int | None): Optional positional suffix for the fallback name.

    Returns:
        FileInfo: File descriptor with base64-encoded content, normalized path,
        and size metadata.
    """
    content_bytes = bytesio_obj.getvalue()

    encoded = base64.b64encode(content_bytes).decode("utf-8")

    # Fall back to a synthetic name when the buffer carries no name attribute.
    name = getattr(bytesio_obj, "name", f"file_{key}" if index is None else f"file_{key}_{index}")
    content_type = getattr(bytesio_obj, "content_type", "unknown")
    description = getattr(bytesio_obj, "description", "")

    # Normalize to an absolute-style path derived from the name.
    path = name if name.startswith("/") else f"/{name}"

    return FileInfo(
        content=encoded,
        path=path,
        name=name,
        content_type=content_type,
        metadata={"description": description},
        size=len(content_bytes),
    )


class FileMappedInput(BaseModel):
    """Structure for storing file mapped inputs."""

    input: Any
    files: list[io.BytesIO]  # BytesIO (or FileInfo-like) objects attached to the input
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @field_serializer("files")
    def serialize_files(self, files: list[io.BytesIO]) -> list[str]:
        # Serialize files by name only; unnamed entries get a positional placeholder.
        names: list[str] = []
        for position, item in enumerate(files):
            names.append(getattr(item, "name", f"file_{position}"))
        return names


class XMLParser:
    """
    Utility class for parsing XML-like output, often generated by LLMs.
    Prioritizes lxml for robustness, with fallbacks for common issues.
    """

    # Tags whose inner formatting (whitespace, nested markup) must be preserved verbatim.
    DEFAULT_PRESERVE_TAGS = ["answer", "thought"]

    @staticmethod
    def _clean_content(text: str) -> str:
        """
        Cleans the input string to remove common LLM artifacts and isolate XML.

        Args:
            text (str): The input string to be cleaned

        Returns:
            str: Cleaned string containing only the XML content
        """
        if not isinstance(text, str):
            return ""

        cleaned = text.strip()

        # Strip markdown code fences wrapping the whole payload.
        if cleaned.startswith("```") and cleaned.endswith("```"):
            cleaned = re.sub(r"^```.*?\n", "", cleaned)
            cleaned = re.sub(r"\n```$", "", cleaned)

        # Isolate the most relevant well-formed XML fragment, preferring one
        # that contains an <answer> tag; otherwise take the last fragment.
        xml_matches = list(re.finditer(r"<(\w+)\b[^>]*>.*?</\1>", cleaned, re.DOTALL))
        if xml_matches:
            for match in reversed(xml_matches):
                candidate = match.group(0)
                if "<answer" in candidate:
                    cleaned = candidate
                    break
            else:
                cleaned = xml_matches[-1].group(0)

        return cleaned

    @staticmethod
    def extract_content_with_regex_fallback(text: str, tag_name: str) -> str | None:
        """
        Extract tag content using regex as a fallback when XML parsing fails.
        Works with both complete and incomplete tags.

        Args:
            text (str): The XML-like text to extract content from
            tag_name (str): The name of the tag to extract content from

        Returns:
            str | None: The extracted content if found, None otherwise
        """
        complete_pattern = f"<{tag_name}[^>]*>(.*?)</{tag_name}>"
        complete_match = re.search(complete_pattern, text, re.DOTALL)

        if complete_match:
            content = complete_match.group(1).strip()
            if content:
                return content
            return None

        # Tag was opened but never closed: capture up to the next foreign tag or EOF.
        incomplete_pattern = f"<{tag_name}[^>]*>(.*?)(?=<(?!/{tag_name})|$)"
        incomplete_match = re.search(incomplete_pattern, text, re.DOTALL)

        if incomplete_match:
            content = incomplete_match.group(1).strip()
            if content:
                return content

        return None

    @staticmethod
    def preprocess_xml_content(
        text: str, required_tags: Sequence[str] | None = None, optional_tags: Sequence[str] | None = None
    ) -> dict[str, tuple[str, str]]:
        """
        Extract raw content for all tags that need special handling before parsing.
        Filters tags based on required and optional tags if provided.

        Args:
            text (str): The XML-like text to preprocess
            required_tags (Sequence[str], optional): List of tags that must be present
            optional_tags (Sequence[str], optional): List of tags that may be present

        Returns:
            dict[str, tuple[str, str]]: Dictionary mapping tag names to (modified_text, raw_content) tuples
        """
        extracted_contents = {}
        tags_to_process = set(XMLParser.DEFAULT_PRESERVE_TAGS)

        # Add required and optional tags to the set if provided
        if required_tags:
            tags_to_process.update(required_tags)
        if optional_tags:
            tags_to_process.update(optional_tags)

        text = XMLParser._escape_unbalanced_reserved_tags(text, tags_to_process)

        for tag in tags_to_process:
            content = XMLParser.extract_content_with_regex_fallback(text, tag)
            if content:
                # Replace the tag body with a placeholder so lxml never sees
                # potentially malformed inner content; restored later.
                tag_pattern = f"<{tag}[^>]*>.*?(?=<(?!/{tag})|$)|<{tag}[^>]*>.*?</{tag}>"
                modified_text = re.sub(tag_pattern, f"<{tag}>CONTENT_PLACEHOLDER_{tag}</{tag}>", text, flags=re.DOTALL)
                extracted_contents[tag] = (modified_text, content)
                text = modified_text

        return extracted_contents

    @staticmethod
    def _escape_unbalanced_reserved_tags(text: str, tags: Sequence[str]) -> str:
        """Escape reserved tag tokens that appear in prose instead of well-formed XML."""

        if not text:
            return text

        for tag in tags:
            open_tag = f"<{tag}>"
            close_tag = f"</{tag}>"

            # Replace stray closing tags without matching opening tags
            search_start = 0
            while True:
                close_index = text.find(close_tag, search_start)
                if close_index == -1:
                    break

                opening_index = text.rfind(open_tag, 0, close_index)
                if opening_index == -1:
                    text = text[:close_index] + f"&lt;/{tag}&gt;" + text[close_index + len(close_tag) :]
                    search_start = close_index + len(f"&lt;/{tag}&gt;")
                else:
                    search_start = close_index + len(close_tag)

            # Replace stray opening tags without matching closing tags
            search_start = 0
            while True:
                open_index = text.find(open_tag, search_start)
                if open_index == -1:
                    break

                closing_index = text.find(close_tag, open_index + len(open_tag))
                if closing_index == -1:
                    text = text[:open_index] + f"&lt;{tag}&gt;" + text[open_index + len(open_tag) :]
                    search_start = open_index + len(f"&lt;{tag}&gt;")
                else:
                    search_start = open_index + len(open_tag)

        return text

    @staticmethod
    def _parse_with_lxml(cleaned_text: str) -> LET._Element | None:
        """
        Attempts to parse the cleaned text using lxml with recovery.

        Args:
            cleaned_text (str): The cleaned XML text to parse

        Returns:
            LET._Element | None: The parsed XML element if successful, None otherwise
        """
        if not cleaned_text:
            return None
        try:
            # Best-effort repair: append/insert missing closing tags before parsing.
            tags_to_check = ["thought", "answer", "action", "action_input", "output"]
            fixed_text = cleaned_text

            for tag in tags_to_check:
                opening_count = len(re.findall(f"<{tag}[^>]*>", fixed_text))
                closing_count = len(re.findall(f"</{tag}>", fixed_text))

                if opening_count > closing_count:
                    logger.debug(f"XMLParser: Adding missing </{tag}> tags")
                    if tag == "output":
                        fixed_text += f"</{tag}>"
                    else:
                        pos = fixed_text.find(f"<{tag}")
                        if pos >= 0:
                            next_tag_pos = fixed_text.find("<", pos + 1)
                            if next_tag_pos > 0 and f"</{tag}>" not in fixed_text[pos:next_tag_pos]:
                                fixed_text = fixed_text[:next_tag_pos] + f"</{tag}>" + fixed_text[next_tag_pos:]

            parser = LET.XMLParser(recover=True, encoding="utf-8")
            root = LET.fromstring(fixed_text.encode("utf-8"), parser=parser)  # nosec: B320
            return root
        except LET.XMLSyntaxError as e:
            logger.warning(f"XMLParser: lxml parsing failed with recovery: {e}. Content: {cleaned_text[:200]}...")
            return None
        except Exception as e:
            logger.error(f"XMLParser: Unexpected error during parsing: {e}. Content: {cleaned_text[:200]}...")
            return None

    @staticmethod
    def _extract_data_lxml(
        root: LET._Element,
        required_tags: Sequence[str],
        optional_tags: Sequence[str] | None = None,
        preserve_format_tags: Sequence[str] | None = None,
    ) -> dict[str, str]:
        """
        Extracts text content from specified tags using XPath.

        Args:
            root (LET._Element): The root XML element to extract data from
            required_tags (Sequence[str]): Tags that must be present in the output
            optional_tags (Sequence[str], optional): Tags to extract if present
            preserve_format_tags (Sequence[str], optional): Tags where original formatting should be preserved

        Returns:
            dict[str, str]: Dictionary mapping tag names to their extracted content

        Raises:
            TagNotFoundError: If a required tag is missing or empty
        """
        data = {}
        optional_tags = optional_tags or []
        preserve_format_tags = list(preserve_format_tags or [])

        for tag in XMLParser.DEFAULT_PRESERVE_TAGS:
            if tag not in preserve_format_tags:
                preserve_format_tags.append(tag)

        all_tags = list(required_tags) + list(optional_tags)

        for tag in all_tags:
            tag_content = None
            element_found = False
            elements = root.xpath(f".//{tag}")
            if elements:
                element_found = True
                for elem in elements:
                    if tag in preserve_format_tags:
                        # Re-serialize and slice out the raw inner markup so
                        # whitespace and nested tags survive untouched.
                        xml_content = LET.tostring(elem, encoding="unicode", method="xml")
                        tag_pattern = re.compile(f"<{tag}[^>]*>(.*?)</{tag}>", re.DOTALL)
                        match = tag_pattern.search(xml_content)
                        text = match.group(1) if match else ""
                    else:
                        text = "".join(elem.itertext()).strip()

                    if not text and tag in required_tags:
                        raise TagNotFoundError(f"Required tag <{tag}> found but contains no text content.")

                    if text:
                        if text.startswith("CONTENT_PLACEHOLDER_"):
                            continue
                        tag_content = text
                        break

            if not element_found:
                # The tag may live beside the parsed root rather than inside it.
                try:
                    if root.getparent() is not None:
                        parent_elements = root.xpath(f"../{tag}")
                        if parent_elements:
                            element_found = True
                            for elem in parent_elements:
                                text_parent_child = "".join(elem.itertext()).strip()
                                if text_parent_child:
                                    tag_content = text_parent_child
                                    break
                except Exception as e:
                    logger.debug(f"XMLParser: Error checking parent for tag '{tag}': {e}")

            if tag_content is not None:
                data[tag] = tag_content
            elif element_found and tag in required_tags:
                raise TagNotFoundError(f"Required tag <{tag}> found but contains no text content.")
            elif not element_found and tag in required_tags:
                raise TagNotFoundError(
                    f"Required tag <{tag}> not found in the XML structure "
                    f"relative to the parsed root element ('{root.tag}') or its parent."
                )

        missing_required_after_all = [tag for tag in required_tags if tag not in data]
        if missing_required_after_all:
            raise TagNotFoundError(f"Required tags missing after extraction: {', '.join(missing_required_after_all)}")

        return data

    @staticmethod
    def _parse_json_fields(data: dict[str, str], json_fields: Sequence[str]) -> dict[str, Any]:
        """
        Parses specified fields in the data dictionary as JSON.

        Args:
            data (dict[str, str]): Dictionary of extracted tag contents
            json_fields (Sequence[str]): List of field names to parse as JSON

        Returns:
            dict[str, Any]: Dictionary with specified fields parsed as JSON objects

        Raises:
            JSONParsingError: If a JSON field cannot be parsed correctly
        """
        parsed_data = data.copy()
        for field in json_fields:
            if field in parsed_data:
                try:
                    # Strip optional markdown fencing before decoding.
                    json_string = re.sub(r"^```(?:json)?\s*|```$", "", parsed_data[field].strip())
                    parsed_data[field] = json.loads(json_string)
                except json.JSONDecodeError as e:
                    error_message = (
                        f"Failed to parse JSON content for field '{field}'. "
                        f"Error: {e}. Original content: '{parsed_data[field][:100]}...'"
                    )
                    guidance = (
                        " Ensure the value is valid JSON with double quotes for keys and strings, "
                        'and proper escaping for special characters (e.g., \\n for newlines, \\" for quotes).'
                    )
                    raise JSONParsingError(error_message + guidance) from e
                except Exception as e:
                    raise JSONParsingError(f"Unexpected error parsing JSON for field '{field}': {e}") from e
        return parsed_data

    @staticmethod
    def parse(
        text: str,
        required_tags: Sequence[str],
        optional_tags: Sequence[str] | None = None,
        json_fields: Sequence[str] | None = None,
        preserve_format_tags: Sequence[str] | None = None,
        attempt_wrap: bool = True,
    ) -> dict[str, Any]:
        """
        Parse XML-like text to extract structured data from specified tags.

        This function employs a multi-stage parsing strategy:
        1. First cleans and preprocesses the input text
        2. Attempts extraction using regex for tags needing special handling
        3. Tries robust XML parsing with lxml (with auto-repair capabilities)
        4. Falls back to regex extraction for any remaining tags
        5. Processes JSON fields if specified

        Each stage has fallback mechanisms to handle various edge cases commonly
        seen in LLM-generated XML, including malformed tags, missing closing tags,
        and inconsistent formatting.

        Args:
            text (str): The XML-like text to parse
            required_tags (Sequence[str]): Tags that must be present in the output
            optional_tags (Sequence[str], optional): Tags to extract if present
            json_fields (Sequence[str], optional): Fields to parse as JSON objects
            preserve_format_tags (Sequence[str], optional): Tags where original formatting
                                                           should be preserved
            attempt_wrap (bool, optional): Whether to try wrapping content in a root tag
                                          if initial parsing fails

        Returns:
            dict[str, Any]: Dictionary mapping tag names to their extracted content

        Raises:
            ParsingError: If the input is empty or invalid
            TagNotFoundError: If a required tag is missing or empty
            JSONParsingError: If a JSON field cannot be parsed
        """
        optional_tags = optional_tags or []
        json_fields = json_fields or []
        preserve_format_tags = list(preserve_format_tags or [])

        required_optional_set = set(required_tags) | set(optional_tags)

        for tag in XMLParser.DEFAULT_PRESERVE_TAGS:
            if tag not in preserve_format_tags:
                preserve_format_tags.append(tag)
        cleaned_text = XMLParser._clean_content(text)
        if not cleaned_text:
            if text and text.strip():
                cleaned_text = text.strip()
            else:
                if required_tags:
                    raise ParsingError("Input text is empty or became empty after cleaning.")
                else:
                    return {}

        tags_to_process = set(XMLParser.DEFAULT_PRESERVE_TAGS)
        if required_tags:
            tags_to_process.update(required_tags)
        if optional_tags:
            tags_to_process.update(optional_tags)

        # Stage 2: regex pre-extraction of fragile tags (placeholders substituted).
        extracted_contents = XMLParser.preprocess_xml_content(cleaned_text, required_tags, optional_tags)

        modified_text = cleaned_text
        if extracted_contents:
            for tag, (tag_modified_text, _) in extracted_contents.items():
                modified_text = tag_modified_text

        result = {}
        for tag, (_, content) in extracted_contents.items():
            if tag in required_optional_set:
                result[tag] = content

        remaining_required = [tag for tag in required_tags if tag not in result]
        remaining_optional = [tag for tag in optional_tags if tag not in result]

        # Stage 3: structural parse with lxml for anything not already extracted.
        try:
            root = XMLParser._parse_with_lxml(modified_text)

            if root is None and attempt_wrap:
                wrapped_text = f"<root>{modified_text}</root>"
                root = XMLParser._parse_with_lxml(wrapped_text)

            if root is not None and (remaining_required or remaining_optional):
                xml_data = XMLParser._extract_data_lxml(
                    root, remaining_required, remaining_optional, preserve_format_tags
                )
                result.update(xml_data)
                remaining_required = [tag for tag in remaining_required if tag not in result]
                remaining_optional = [tag for tag in remaining_optional if tag not in result]
        except TagNotFoundError as e:
            logger.debug(
                "XMLParser: lxml extraction missing tag (will retry via fallback): %s",
                e,
            )
        except Exception as e:
            logger.warning(f"XMLParser: XML parsing failed: {e}")

        # Stage 4: regex fallback against the ORIGINAL text for leftovers.
        for tag in remaining_required:
            content = XMLParser.extract_content_with_regex_fallback(text, tag)
            if content:
                result[tag] = content
            else:
                empty_tag_pattern = f"<{tag}[^>]*>\\s*</{tag}>"
                if re.search(empty_tag_pattern, text):
                    raise TagNotFoundError(f"Required tag <{tag}> found but contains no text content.")
                else:
                    raise TagNotFoundError(f"Required tag <{tag}> not found even with fallback methods")

        for tag in remaining_optional:
            content = XMLParser.extract_content_with_regex_fallback(text, tag)
            if content:
                result[tag] = content

        if result:
            result = XMLParser._restore_placeholder_mentions(result, tags_to_process)

        # Stage 5: decode any fields declared as JSON.
        if json_fields and result:
            result = XMLParser._parse_json_fields(result, json_fields)

        return result

    @staticmethod
    def _restore_placeholder_mentions(data: dict[str, Any], tags: Sequence[str]) -> dict[str, Any]:
        """Replace placeholder artifacts with escaped tag mentions."""

        restored = {}
        for key, value in data.items():
            if isinstance(value, str):
                for tag in tags:
                    placeholder = f"<{tag}>CONTENT_PLACEHOLDER_{tag}</{tag}>"
                    if placeholder in value:
                        value = value.replace(placeholder, f"&lt;{tag}&gt;")
                    # Escape any remaining raw occurrences of reserved tags
                    value = value.replace(f"<{tag}>", f"&lt;{tag}&gt;")
                    value = value.replace(f"</{tag}>", f"&lt;/{tag}&gt;")
                restored[key] = value
            else:
                restored[key] = value

        return restored

    @staticmethod
    def extract_first_tag_lxml(text: str, tags: Sequence[str]) -> str | None:
        """
        Extracts the text content of the first tag found from the list using lxml.
        Useful for simple cases like extracting just the final answer.

        Args:
            text (str): The XML-like text to extract content from
            tags (Sequence[str]): Ordered list of tags to look for

        Returns:
            str | None: Content of the first found tag, or None if no tags are found
        """
        cleaned_text = XMLParser._clean_content(text)
        if not cleaned_text:
            return None

        root = XMLParser._parse_with_lxml(cleaned_text)

        if root is None:
            wrapped_text = f"<root>{cleaned_text}</root>"
            root = XMLParser._parse_with_lxml(wrapped_text)

        if root is None:
            logger.warning(f"XMLParser: extract_first_tag_lxml failed to parse: {cleaned_text[:200]}...")
            return None

        for tag in tags:
            elements = root.xpath(f".//{tag}")
            if elements:
                for elem in elements:
                    content = "".join(elem.itertext()).strip()
                    if content:
                        return content
        return None

    @staticmethod
    def extract_first_tag_regex(text: str, tags: Sequence[str]) -> str | None:
        """
        Fallback method: Extracts the text content of the first tag found using regex.
        Less reliable than lxml, use only when lxml fails completely.

        Args:
            text (str): The XML-like text to extract content from
            tags (Sequence[str]): Ordered list of tags to look for

        Returns:
            str | None: Content of the first found tag, or None if no tags are found
        """
        if not isinstance(text, str):
            return None

        for tag in tags:
            match = re.search(f"<{tag}\\b[^>]*>(.*?)</{tag}>", text, re.DOTALL | re.IGNORECASE)
            if match:
                content = match.group(1).strip()
                if content:
                    return content
        return None

    @staticmethod
    def parse_unified_xml_format(text: str) -> dict[str, Any]:
        """
        Parse XML content using a unified structure for both tool calls and final answers.

        Args:
            text (str): The XML content to parse

        Returns:
            dict: A dictionary containing parsed data with the following structure:
                For tool calls:
                {
                    "is_final": False,
                    "thought": "extracted thought text",
                    "tools": [
                        {"name": "tool name", "input": parsed_json_input},
                        {"name": "tool name", "input": parsed_json_input},
                        ...
                    ]
                }

                For final answer:
                {
                    "is_final": True,
                    "thought": "extracted thought text",
                    "answer": "final answer text"
                }

        Raises:
            XMLParsingError: If the XML cannot be parsed
            TagNotFoundError: If required tags are missing
        """
        try:
            cleaned_text = XMLParser._clean_content(text)
            if not cleaned_text:
                raise XMLParsingError("Empty or invalid XML content")

            # First interpretation: a final answer (thought + answer tags).
            try:
                parsed_data = XMLParser.parse(cleaned_text, required_tags=["thought", "answer"], optional_tags=["o"])
                return {"is_final": True, "thought": parsed_data.get("thought"), "answer": parsed_data.get("answer")}
            except TagNotFoundError:
                pass

            # Second interpretation: a tool-call block.
            root = XMLParser._parse_with_lxml(cleaned_text)
            if root is None:

                wrapped_text = f"<root>{cleaned_text}</root>"
                root = XMLParser._parse_with_lxml(wrapped_text)

            if root is None:
                raise XMLParsingError("Failed to parse XML content with lxml")

            thought_elems = root.xpath(".//thought")
            if not thought_elems:
                raise TagNotFoundError("Required tag <thought> not found")

            thought = "".join(thought_elems[0].itertext()).strip()

            tool_calls_elem = None
            for tag_name in ["tool_calls", "tools"]:
                elems = root.xpath(f".//{tag_name}")
                if elems:
                    tool_calls_elem = elems[0]
                    break

            if tool_calls_elem is None:
                raise TagNotFoundError("Required tag <tool_calls> or <tools> not found")

            tools = []
            for tool_elem in tool_calls_elem.xpath(".//tool"):

                # Accept several aliases for the tool-name tag.
                name_elem = None
                for name_tag in ["n", "name", "tool_name"]:
                    name_elems = tool_elem.xpath(f".//{name_tag}")
                    if name_elems:
                        name_elem = name_elems[0]
                        break

                if name_elem is None:
                    continue

                input_elem = None
                for input_tag in ["input", "tool_input"]:
                    input_elems = tool_elem.xpath(f".//{input_tag}")
                    if input_elems:
                        input_elem = input_elems[0]
                        break

                if input_elem is None:
                    continue

                tool_name = "".join(name_elem.itertext()).strip()
                input_json_str = "".join(input_elem.itertext()).strip()

                try:

                    input_json_str = re.sub(r"^```(?:json)?\s*|```$", "", input_json_str.strip())
                    tool_input = json.loads(input_json_str)
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to parse JSON for tool {tool_name}: {e}")
                    continue

                tools.append({"name": tool_name, "input": tool_input})

            if not tools:
                raise TagNotFoundError("No valid tool elements found with both name and input tags")

            return {"is_final": False, "thought": thought, "tools": tools}

        except (XMLParsingError, TagNotFoundError):
            raise
        except Exception as e:
            raise XMLParsingError(f"Error parsing XML: {str(e)}") from e


def create_message_from_input(input_data: dict) -> Message | VisionMessage:
    """
    Build the appropriate message type from the input payload, automatically
    detecting images supplied via either the images or files fields.

    Args:
        input_data (dict): Input data dictionary containing:
            - 'input': Text input string
            - 'images': List of image data (URLs, bytes, or BytesIO objects)
            - 'files': List of file data (bytes or BytesIO objects)

    Returns:
        Message or VisionMessage: Appropriate message type for the input
    """
    prompt_text = input_data.get("input", "")
    raw_images = input_data.get("images", []) or []
    attachments = input_data.get("files", []) or []

    pending_images = list(raw_images) if isinstance(raw_images, list) else [raw_images]

    # Promote any file attachments that are actually images.
    for attachment in attachments:
        if is_image_file(attachment):
            logger.debug(
                f"File detected as image, adding to vision processing: {getattr(attachment, 'name', 'unnamed')}"
            )
            pending_images.append(attachment)

    if not pending_images:
        return Message(role=MessageRole.USER, content=prompt_text)

    parts = []
    if prompt_text:
        parts.append(VisionMessageTextContent(text=prompt_text))

    for img in pending_images:
        try:
            if isinstance(img, str):
                if img.startswith(("http://", "https://", "data:")):
                    url = img
                else:
                    # Any other string is treated as a local filesystem path.
                    with open(img, "rb") as fh:
                        url = bytes_to_data_url(fh.read())
            else:
                raw = img.getvalue() if isinstance(img, io.BytesIO) else img
                url = bytes_to_data_url(raw)

            parts.append(VisionMessageImageContent(image_url=VisionMessageImageURL(url=url)))
        except Exception as e:
            # Best-effort: skip images that cannot be processed.
            logger.error(f"Error processing image: {str(e)}")

    return VisionMessage(content=parts, role=MessageRole.USER)


def is_image_file(file) -> bool:
    """
    Determine if a file is an image by examining its content.

    Checks well-known magic numbers first, then falls back to MIME sniffing
    via the ``filetype`` library.

    Args:
        file: File-like object or bytes

    Returns:
        bool: True if the file is an image, False otherwise
    """
    try:
        if isinstance(file, io.BytesIO):
            # Read the header without disturbing the caller's stream position.
            pos = file.tell()
            file.seek(0)
            file_bytes = file.read(32)
            file.seek(pos)
        elif isinstance(file, bytes):
            file_bytes = file[:32]
        else:
            return False

        signatures = {
            b"\xff\xd8\xff": "jpg/jpeg",  # JPEG
            b"\x89PNG\r\n\x1a\n": "png",  # PNG
            b"GIF87a": "gif",  # GIF87a
            b"GIF89a": "gif",  # GIF89a
            b"MM\x00*": "tiff",  # TIFF (big endian)
            b"II*\x00": "tiff",  # TIFF (little endian)
            b"BM": "bmp",  # BMP
        }

        for sig, fmt in signatures.items():
            if file_bytes.startswith(sig):
                return True

        # A bare "RIFF" prefix is NOT enough: WAV and AVI are RIFF containers
        # too. WebP additionally requires the "WEBP" fourcc at offset 8
        # (bytes 4-8 hold the chunk size).
        if file_bytes.startswith(b"RIFF") and file_bytes[8:12] == b"WEBP":
            return True

        # Fall back to MIME sniffing for less common formats.
        if isinstance(file, io.BytesIO):
            pos = file.tell()
            file.seek(0)
            mime = filetype.guess_mime(file.read(4096))
            file.seek(pos)
            return mime is not None and mime.startswith("image/")
        elif isinstance(file, bytes):
            mime = filetype.guess_mime(file)
            return mime is not None and mime.startswith("image/")

        return False
    except Exception as e:
        logger.error(f"Error checking if file is an image: {str(e)}")
        return False


def bytes_to_data_url(image_bytes: bytes) -> str:
    """
    Convert image bytes to a data URL.

    Args:
        image_bytes (bytes): Raw image bytes

    Returns:
        str: Data URL string (format: data:image/jpeg;base64,...)

    Raises:
        ValueError: If MIME detection or base64 encoding fails; the original
            exception is chained as the cause.
    """
    try:
        mime_type = filetype.guess_mime(image_bytes)
        if not mime_type:
            # Fall back to magic-number checks for the most common formats.
            if image_bytes[:2] == b"\xff\xd8":
                mime_type = "image/jpeg"
            elif image_bytes[:8] == b"\x89PNG\r\n\x1a\n":
                mime_type = "image/png"
            elif image_bytes[:6] in (b"GIF87a", b"GIF89a"):
                mime_type = "image/gif"
            else:
                mime_type = "application/octet-stream"

        encoded = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:{mime_type};base64,{encoded}"
    except Exception as e:
        logger.error(f"Error converting image to data URL: {str(e)}")
        # Chain the cause so callers can inspect the underlying failure.
        raise ValueError(f"Failed to convert image to data URL: {str(e)}") from e


def process_tool_output_for_agent(content: Any, max_tokens: int = TOOL_MAX_TOKENS, truncate: bool = True) -> str:
    """
    Process tool output for agent consumption.

    Converts various tool outputs into a string: dicts (with or without a
    'content' key, the 'files' key is always dropped), lists/tuples (joined
    line-by-line), and anything else via str(). Template-style ``{{ ... }}``
    placeholders are unwrapped. If the resulting string exceeds the maximum
    allowed length and ``truncate`` is set, the middle is cut out.

    Args:
        content: The output from tool execution, which can be of various types.
        max_tokens: Maximum allowed token count for the content. The effective character
            limit is computed as max_tokens * CHARS_PER_TOKEN (assuming ~4 characters per token).
        truncate: Whether to truncate the content if it exceeds the maximum length.

    Returns:
        A processed string suitable for agent consumption.
    """
    if not isinstance(content, str):
        if isinstance(content, dict):
            # Raw file payloads are not useful as agent-readable text.
            filtered_content = {k: v for k, v in content.items() if k != "files"}

            if "content" in filtered_content:
                inner_content = filtered_content["content"]
                content = inner_content if isinstance(inner_content, str) else json.dumps(inner_content, indent=2)
            else:
                content = json.dumps(filtered_content, indent=2) if filtered_content else ""
        elif isinstance(content, (list, tuple)):
            content = "\n".join(str(item) for item in content)
        else:
            content = str(content)

    max_len_in_char: int = max_tokens * CHARS_PER_TOKEN
    # Unwrap {{ ... }} placeholders so they are not re-rendered downstream.
    content = re.sub(r"\{\{\s*(.*?)\s*\}\}", r"\1", content)

    if len(content) > max_len_in_char and truncate:
        truncation_message: str = "\n...[Content truncated]...\n"
        # Keep head and tail, reserving ~100 chars for the truncation notice.
        half_length: int = (max_len_in_char - 100) // 2
        if half_length > 0:
            content = content[:half_length] + truncation_message + content[-half_length:]
        else:
            # Limit too small for head+tail splitting: a non-positive slice
            # length would otherwise return MORE text than allowed
            # (e.g. content[-0:] is the whole string). Hard-truncate instead.
            content = content[: max(max_len_in_char, 0)] + truncation_message

    return content


def extract_thought_from_intermediate_steps(intermediate_steps):
    """Extract the agent's thought text from an intermediate-steps mapping.

    Inspects each step's ``model_observation.initial`` payload and tries,
    in order: a JSON object containing a ``"thought"`` key, a
    ``<thought>...</thought>`` tag, and a plain-text ``Thought: ...``
    prefix terminated by a blank line or an ``Answer:`` line.

    Args:
        intermediate_steps: Mapping of step id -> step data (dicts), or a
            falsy value.

    Returns:
        The first thought string found, or None.
    """
    if not intermediate_steps:
        return None

    # Step keys are irrelevant here; only the payloads are inspected.
    for step_value in intermediate_steps.values():
        if not isinstance(step_value, dict) or "model_observation" not in step_value:
            continue

        model_obs = step_value["model_observation"]
        if not isinstance(model_obs, dict) or "initial" not in model_obs:
            continue

        initial = model_obs["initial"]
        if not isinstance(initial, str):
            # Defensive: skip non-string payloads instead of crashing on
            # .startswith / regex search.
            continue

        # Case 1: JSON-formatted observation with a "thought" field.
        if initial.startswith("{") and '"thought"' in initial:
            try:
                json_data = json.loads(initial)
                if "thought" in json_data:
                    return json_data["thought"]
            except json.JSONDecodeError:
                pass

        # Case 2: XML-style <thought> tag.
        if "<thought>" in initial:
            thought_match = re.search(r"<thought>\s*(.*?)\s*</thought>", initial, re.DOTALL)
            if thought_match:
                return thought_match.group(1)

        # Case 3: plain-text "Thought: ..." prefix.
        thought_match = re.search(r"Thought:\s*(.*?)(?:\n\n|\nAnswer:)", initial, re.DOTALL)
        if thought_match:
            return thought_match.group(1)

    return None


class ToolCacheEntry(BaseModel):
    """Single key entry in tool cache.

    ``action_input`` is normalized to a canonical JSON string (keys sorted)
    so that entries built from equal-but-differently-ordered dicts compare
    and hash identically. The model is frozen so instances are hashable and
    usable as cache keys.
    """

    action: str
    action_input: dict | str

    model_config = ConfigDict(frozen=True)

    @model_validator(mode="before")
    @classmethod
    def convert_to_str(cls, data):
        """Canonicalize ``action_input`` without mutating the caller's dict."""
        if isinstance(data, dict):
            # Build a copy: assigning into `data` directly would leak the
            # canonicalized string back into the caller's dictionary.
            data = {**data, "action_input": json.dumps(data.get("action_input"), sort_keys=True)}
        return data


class SummarizationConfig(BaseModel):
    """Configuration for agent history summarization.

    Attributes:
        enabled (bool): Whether history summarization is enabled. Defaults to False.
        max_token_context_length (int | None): Maximum number of tokens in prompt after
          which summarization will be applied. Defaults to None.
        context_usage_ratio (float): Relative percentage of tokens in prompt after which summarization will be applied.
        context_history_length (int): Number of history messages that will be prepended.
    """

    enabled: bool = False
    max_token_context_length: int | None = None
    # Fraction of the context window that may be used before summarizing.
    context_usage_ratio: float = 0.8
    context_history_length: int = 4
